package com.carol.bigdata.task.model.train

import com.carol.bigdata.Config
import com.carol.bigdata.task.model.algo.{ModelTrait, ModelUtil}
import com.carol.bigdata.task.model.feature.FeatureUtil
import com.carol.bigdata.utils.{Flag, FuncUtil, TimeUtil}
import org.apache.spark.SparkContext
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.classification.OneVsRest
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

/**
 * End-to-end training job for the user-retention model.
 *
 * Reads user-profile features and retention labels from HBase, joins them into
 * per-label feature DataFrames, then trains, cross-validates, saves and
 * evaluates one classifier per retention label column.
 */
object TrainRetention {
    // HBase table holding user-profile data
    val uTable = "user_profile"
    // Column families of the user-profile table
    val ucfList = List("key", "info", "env_tag", "favor_tag", "pay_tag", "social_tag", "strategy_tag", "adjust_tag")
    // Field names that make up the HBase rowKey
    val keyColumn = List("time", "game_id", "uid")
    // Numeric feature columns: raw value plus its "_total" aggregate
    val numFeatColumnList = List("level") ::: List("level").map(_ + "_total")
    // No double-typed feature columns for this model.
    // Annotated explicitly: a bare List() would infer List[Nothing].
    val doubleFeatColumns: List[String] = List()
    // Map-typed feature columns ("$"-prefixed) plus their "_total" aggregates
    val mapFeatColumn = List("$address", "$device", "$login_time", "$os", "$game")
    val mapFeatColumnList = mapFeatColumn ::: mapFeatColumn.map(_ + "_total")

    // HBase table holding retention labels
    val lTable = "retention_label"
    // Column families of the retention-label table
    val lcfList: List[String] = List("key", "retention")
    // Retention horizons (in days) used to derive the label column names
    val dayRange: List[String] = List("1")
    val lbCols: List[String] = dayRange.map(x => "active_r" + x)
    // Possible classes for each retention label ("0" = churned, "1" = retained — presumably; confirm upstream)
    val retentionList = List("0", "1")
    val lbList: List[List[String]] = List.fill(lbCols.length)(retentionList) // List( List(0, 1) )
    val uTagFuncMap: Map[String, String => String] = Map("tag" -> FuncUtil.defaultTagFunc)

    // Model configuration
    val modelName = "LR"
    val initParams = Map[String, Any]()
    val objective = "binary"

    /**
     * Loads features and labels from HBase, joins them, and trains one model
     * per label column.
     *
     * @param hbaseParams HBase connection parameters
     * @param spark       active SparkSession
     * @param game        game id, used as the rowKey filter pattern
     */
    def run(hbaseParams: Map[String, String],
            spark: SparkSession,
            game: String): Unit = {
        // 1. Load the user-profile RDD (last 30 days), filtered by game id.
        //    `game` is already a String; no interpolation needed.
        val pattern: String = game
        val userProfileRDD = FeatureUtil.calUserProfileRDD(hbaseParams, spark, uTable, ucfList, keyColumn, numFeatColumnList, mapFeatColumnList, pattern, filterMode = "PATTERN")
        // NOTE(review): count() triggers a full Spark job just for logging; consider removing in production.
        println(s"userProfileRDD.count(): ${userProfileRDD.count()}")
        userProfileRDD.take(5).foreach(println)
        // 2. Load the retention-label RDD (last 30 days)
        println(s"lbList: $lbList")
        val labelRDD = FeatureUtil.calLabelRDD(hbaseParams, spark, lTable, lcfList, keyColumn, lbCols, lbList, pattern)
        println(s"labelRDD.count(): ${labelRDD.count()}")
        labelRDD.take(5).foreach(println)
        // 3. Join features with labels — one DataFrame per label column
        val featureDF: List[DataFrame] = FeatureUtil.calHashFeatureLabelDF(spark, userProfileRDD, labelRDD, keyColumn, doubleFeatColumns, mapFeatColumnList, lbCols, numFeatures = 10)
        featureDF.foreach { df =>
            df.show(5, truncate = false)
            // 4. Split 80/20 and train. NOTE(review): randomSplit is unseeded,
            //    so splits are not reproducible across runs — confirm intended.
            val data: Array[Dataset[Row]] = df.randomSplit(Array(0.8, 0.2))
            println(s"train: ${data(0).count()}, test: ${data(1).count()}")
            train(modelName, initParams, objective, data(0), data(1))
        }
    }

    /**
     * Cross-validates, retrains with the tuned hyper-parameters, saves and
     * evaluates a model.
     *
     * @param modelName  model key understood by ModelUtil (e.g. "LR")
     * @param initParams initial model hyper-parameters
     * @param objective  "binary" or a multi-class objective string
     * @param trainData  training split
     * @param testData   held-out evaluation split
     */
    def train(modelName: String,
              initParams: Map[String, Any],
              objective: String,
              trainData: DataFrame,
              testData: DataFrame): Unit = {
        // Step 1: cross-validate to find the best hyper-parameters.
        val modelTrait: ModelTrait = ModelUtil.buildModel(modelName)
        modelTrait.init(initParams)
        val numClass = if (objective.contains("binary")) 2 else 3
        val pipeline = modelTrait.buildPipeline(numClass = numClass, objective = objective)
        val crossModel = modelTrait.buildValidator(pipeline, objective = objective).fit(trainData)
        // Step 2: rebuild a fresh model carrying the CV-tuned parameters.
        val newModelTrait = ModelUtil.buildModel(modelName)
        //val bestParamsMap = modelTrait.getBestParams(crossModel)
        //newModelTrait.init(initParams ++ bestParamsMap)    // override initParams with bestParams, then initialize
        newModelTrait.init(initParams)
        newModelTrait.updateTuneParamsFromCV(crossModel, maxBy = true, objective = objective)
        val newPipeline = newModelTrait.buildPipeline(numClass = numClass, objective = objective)
        // For binary objectives (and lightgbm/xgboost, which handle multi-class
        // natively) the last stage IS the classifier; otherwise it is wrapped in
        // a OneVsRest and the params live on the inner classifier.
        val newParamsMap = {
            if (objective.contains("binary") || List("lightgbm", "xgboost").contains(modelName.toLowerCase))
                newPipeline.getStages.last.extractParamMap()
            else
                newPipeline.getStages.last.asInstanceOf[OneVsRest].getClassifier.extractParamMap()
        }
        println(s"newParamsMap after update: $newParamsMap")
        // Step 3: fit on the full training split, save, and evaluate on test.
        val model: PipelineModel = newPipeline.fit(trainData)
        val savePath = if (objective.contains("binary")) s"binary_${modelName}_model" else s"multi_${modelName}_model"
        newModelTrait.saveModel(model, savePath)
        val results = model.transform(testData)
        results.show(100, truncate = false)
        newModelTrait.evalModel(results)
    }

    /**
     * Entry point: builds a local SparkSession and runs the training job,
     * timing the whole run.
     */
    def main(args: Array[String]): Unit = {
        Flag.Parse(args)
        val spark: SparkSession = SparkSession.builder()
          .appName("TrainRetention")
          .master("local[*]")
          .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
          .getOrCreate()
        val sc: SparkContext = spark.sparkContext
        sc.setLogLevel("ERROR")
        println(s"spark: $spark")
        // FIX: the original broadcast(...).value created a broadcast variable
        // only to dereference it immediately on the driver — a no-op that also
        // leaked the broadcast (never destroyed). The params map is passed as a
        // plain value; Spark serializes it into task closures as needed.
        val hbaseParams: Map[String, String] = Config.hbaseParams
        val game = "5"
        try {
            TimeUtil.timer(run(hbaseParams, spark, game))
        } finally {
            // Release Spark resources even if the run fails.
            spark.stop()
        }
    }
}
